skip h2d and d2h copies between forward functions in gemma4-31b #20286
Gasoonjia wants to merge 3 commits into
@Gasoonjia has exported this pull request. If you are a Meta employee, you can view the originating Diff in D108661628.
Summary: This diff updates the gemma4-31b export and runtime pipeline to skip the H2D and D2H copies between prefill and decode, and between consecutive decode rounds.
Differential Revision: D108661628
…ias)

Eliminate the per-decode-round token D2H->H2D round-trip:
- sampler.py: emit the sampled token as int64 (was float32) so the decode method's int64 token output can be aliased directly as the next forward's int64 token input (value-preserving: argmax index; token ids < 2^24).
- main.cpp: read_token reads int64; each forward's on-device output token is aliased via make_tensor_ptr and fed straight back as the next step's token input (prefill->decode and decode->decode).

Only the per-round position H2D remains.

Measured (int6/gguf, cuda graph OFF, p19/d128):
- post-load HtoD 261->132 (token H2D removed; ~= decode length)
- DtoH/DtoD counts unchanged (129), bytes 4B->8B (token now int64)

Greedy output byte-identical to prior export.
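The "value-preserving" remark in the commit message relies on float32 having a 24-bit significand, so every integer up to 2^24 is represented exactly. The sketch below checks that property with a round trip; `survives_float32` is a hypothetical helper written for illustration and is not part of sampler.py or main.cpp.

```cpp
#include <cstdint>

// Hypothetical helper (not from the PR): returns true when a token id
// survives an int64 -> float32 -> int64 round trip unchanged.
// float32 has a 24-bit significand, so integers up to 2^24 are exact.
bool survives_float32(int64_t token_id) {
  float as_float = static_cast<float>(token_id);   // what the old float32 output stored
  return static_cast<int64_t>(as_float) == token_id;
}
```

Under this check, any id at or below 2^24 (16777216) round-trips exactly, while 2^24 + 1 does not, which is why switching the sampler output from float32 to int64 does not change the emitted token values in that range.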
Kill the per-decode-round position H2D (the last per-round host->device copy left after Option A): upload the full decode position array to device once (single H2D), then at each step copy that step's position from the array into the fixed position input slot with an on-device D2D. Token stays aliased on device (Option A).

Per-round HtoD is now 0, independent of decode length; the fixed input slot keeps it cuda-graph-safe (with cuda graph on, the D2D becomes a captured cudaMemcpyAsync on the decode stream into the same slot).

Measured (int6/gguf, cuda graph OFF, p19/d128):
- post-load HtoD 132->5 (per-round H2D=0)
- DtoD 129->257 (+128 per-round pos d2d, the intended H2D->d2d trade)
- DtoH unchanged (129)

Greedy output byte-identical to prior runs. Runner-only; reuses the int64-output export (no re-export).
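The H2D->D2D trade described in this commit can be illustrated with a host-only simulation that counts copies instead of issuing them. This is a rough sketch, not the runner's code: it makes no CUDA calls, "device" memory is an ordinary `std::vector`, and `CopyCounters`, `run_per_step_upload`, and `run_preuploaded` are hypothetical names introduced here.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical counters: tally simulated copies rather than calling CUDA.
struct CopyCounters {
  int h2d = 0;  // simulated host->device copies
  int d2d = 0;  // simulated device->device copies
};

// Before: every decode step writes its position from host into the fixed slot.
CopyCounters run_per_step_upload(int64_t start_pos, int steps,
                                 std::vector<int64_t>& seen) {
  CopyCounters c;
  int64_t slot = 0;  // stands in for the fixed position input slot
  for (int i = 0; i < steps; ++i) {
    slot = start_pos + i;  // simulated H2D of one position
    ++c.h2d;
    seen.push_back(slot);  // position the forward call would observe
  }
  return c;
}

// After: upload all positions once, then fill the slot with a D2D per step.
CopyCounters run_preuploaded(int64_t start_pos, int steps,
                             std::vector<int64_t>& seen) {
  CopyCounters c;
  std::vector<int64_t> host_positions(steps);
  for (int i = 0; i < steps; ++i) host_positions[i] = start_pos + i;

  std::vector<int64_t> device_positions = host_positions;  // single simulated H2D
  ++c.h2d;

  int64_t slot = 0;  // same fixed slot address every step
  for (int i = 0; i < steps; ++i) {
    slot = device_positions[i];  // simulated on-device D2D
    ++c.d2d;
    seen.push_back(slot);
  }
  return c;
}
```

With 128 decode steps, the first variant performs 128 simulated H2D copies and the second performs 1 H2D plus 128 D2D copies while presenting the same positions to each step. That matches the shape of the reported change (per-round H2D dropping to 0, DtoD rising by 128); the simulation does not model the other transfers behind the absolute counts in the measurements.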
aefe4cb to ef29abd